﻿using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using System.Data;
using System.Data.Common;
using System.Data.Entity;
using System.Threading;
using System.Data.Entity.Core;
using System.Data.Entity.Core.Objects;
using System.Data.Entity.Infrastructure;
using Microsoft.AspNet.Identity.EntityFramework; 
using App.Core.Entities.Foundation;
using App.Core.Entities.Identity;
using App.Core.Entities.Vendor;
using App.Core.Entities.Dev;
using App.Core.Logging;
using App.Data.Configurations.Identity;
using App.Data.Configurations.Vendor;
using App.Data.Configurations.Dev;

namespace App.Data
{
    /// <summary>
    /// Entity Framework context for the application: the ASP.NET Identity tables plus the
    /// Vendor and Dev modules. Also implements <see cref="IDbContext"/> (explicit entity-state
    /// helpers and a Commit/Rollback transaction wrapper) and stamps audit columns on every
    /// added or modified <see cref="BaseEntity"/> whenever changes are saved.
    /// </summary>
    public class AppDbContext : IdentityDbContext<UserEntity,RoleEntity,int,UserLoginEntity,UserRoleEntity,UserClaimEntity>,IDbContext
    {
        private ObjectContext objectContext;
        private DbTransaction transaction;

        #region Constructors

        /// <summary>
        /// Disables the database initializer for this context type. A static constructor runs
        /// exactly once (thread-safely, guaranteed by the runtime) before the first instance is
        /// created, so every constructor below is covered.
        /// </summary>
        static AppDbContext()
        {
            // Fully qualified: inside this class the simple name "Database" is also the instance property.
            System.Data.Entity.Database.SetInitializer<AppDbContext>(null);
        }

        /// <summary>
        /// Creates a context bound to the "AppDbConnection" connection string, without logging.
        /// </summary>
        public AppDbContext()
            : this("Name=AppDbConnection", null)
        {
        }

        /// <summary>
        /// Creates a context for the given connection.
        /// </summary>
        /// <param name="nameOrConnectionString">A connection string or a "Name=..." reference to a configured one.</param>
        /// <param name="logger">Optional. When non-null, EF's <c>Database.Log</c> output is forwarded to <c>logger.Log</c>.</param>
        public AppDbContext(string nameOrConnectionString, ILogger logger)
            : base(nameOrConnectionString)
        {
            if (logger != null)
            {
                Database.Log = logger.Log;
            }
        }

        #endregion

        /// <summary>
        /// Registers the entity configurations (table mappings and constraints) for the Identity,
        /// Vendor and Dev modules, after the base Identity model has been configured.
        /// </summary>
        /// <param name="modelBuilder">The builder that defines the model for this context.</param>
        protected override void OnModelCreating(DbModelBuilder modelBuilder)
        {
            base.OnModelCreating(modelBuilder);

            #region Identity
            modelBuilder.Configurations.Add(new UserConfiguration());
            modelBuilder.Configurations.Add(new RoleConfiguration());
            modelBuilder.Configurations.Add(new MenuConfiguration());
            modelBuilder.Configurations.Add(new LogConfiguration());
            modelBuilder.Configurations.Add(new ModuleConfiguration());

            // Rename the remaining Identity tables to the project's "Sys_" naming scheme.
            modelBuilder.Entity<UserClaimEntity>().ToTable("Sys_UserClaim");
            modelBuilder.Entity<UserLoginEntity>().ToTable("Sys_UserLogin");
            modelBuilder.Entity<UserRoleEntity>().ToTable("Sys_UserRole");
            #endregion

            #region Vendor

            modelBuilder.Configurations.Add(new VendorConfiguration());
            modelBuilder.Configurations.Add(new VendorAttachfileConfiguration());
            modelBuilder.Configurations.Add(new ConsultantConfiguration());
            modelBuilder.Configurations.Add(new ConsultantAttachfileConfiguration());

            #endregion

            #region DEV

            modelBuilder.Configurations.Add(new BugConfiguration());
            modelBuilder.Configurations.Add(new BugAssignConfiguration());
            modelBuilder.Configurations.Add(new BugHistoryConfiguration());
            modelBuilder.Configurations.Add(new BugAttachfileConfiguration());
            modelBuilder.Configurations.Add(new PublishConfiguration());

            #endregion
        }

        #region SaveChanges overrides (audit stamping)

        /// <summary>
        /// Stamps audit columns on tracked <see cref="BaseEntity"/> instances, then saves.
        /// </summary>
        /// <returns>The number of state entries written to the database.</returns>
        public override int SaveChanges()
        {
            ApplyAuditFields();
            return base.SaveChanges();
        }

        /// <summary>
        /// Async counterpart of <see cref="SaveChanges()"/>. It applies the same audit stamping, so
        /// rows saved through <c>SaveChangesAsync()</c> / <see cref="CommitAsync()"/> are audited too.
        /// </summary>
        /// <param name="cancellationToken">Token used to cancel the database operation.</param>
        /// <returns>The number of state entries written to the database.</returns>
        public override async Task<int> SaveChangesAsync(CancellationToken cancellationToken)
        {
            ApplyAuditFields();
            return await base.SaveChangesAsync(cancellationToken).ConfigureAwait(false);
        }

        /// <summary>
        /// Sets audit columns on every added or modified <see cref="BaseEntity"/>.
        /// Added entities get CreatedBy/CreatedDate. Modified entities have those two columns
        /// excluded from the UPDATE so they cannot be overwritten. Both kinds get UpdatedBy/UpdatedDate.
        /// </summary>
        private void ApplyAuditFields()
        {
            // Materialize first: property state is mutated below while processing the entries.
            List<DbEntityEntry> auditableEntries = ChangeTracker.Entries()
                .Where(x => x.Entity is BaseEntity
                    && (x.State == EntityState.Added || x.State == EntityState.Modified))
                .ToList();

            if (auditableEntries.Count == 0)
            {
                return;
            }

            // Read once per save so every row in the same batch carries identical audit values.
            string identityName = GetCurrentIdentityName();
            // NOTE(review): local server time (not UTC) is kept to match previously stored data.
            DateTime now = DateTime.Now;

            foreach (DbEntityEntry entry in auditableEntries)
            {
                var entity = (BaseEntity)entry.Entity;

                if (entry.State == EntityState.Added)
                {
                    entity.CreatedBy = identityName;
                    entity.CreatedDate = now;
                }
                else
                {
                    var typedEntry = Entry(entity);
                    typedEntry.Property(x => x.CreatedBy).IsModified = false;
                    typedEntry.Property(x => x.CreatedDate).IsModified = false;
                }

                entity.UpdatedBy = identityName;
                entity.UpdatedDate = now;
            }
        }

        /// <summary>
        /// Name of the current thread principal's identity, or null when there is no principal or
        /// identity (e.g. in background work). The lookup itself never throws.
        /// </summary>
        private static string GetCurrentIdentityName()
        {
            var principal = Thread.CurrentPrincipal;
            if (principal == null || principal.Identity == null)
            {
                return null;
            }
            return principal.Identity.Name;
        }

        #endregion

        /// <summary>
        /// Factory method, e.g. for per-request registration. Creates a context on
        /// "AppDbConnection" without logging.
        /// </summary>
        /// <returns>A new <see cref="AppDbContext"/>.</returns>
        public static AppDbContext Create()
        {
            return new AppDbContext(nameOrConnectionString: "Name=AppDbConnection", logger:null);
        }

        #region IDbset

        #region Identity
        public DbSet<MenuEntity> MenuEntities { get; set; }
        public DbSet<LogEntity> LogEntities { get; set; }
        public DbSet<ModuleEntity> ModuleEntities { get; set; }
        #endregion

        #region Vendor

        public IDbSet<VendorEntity> VendorEntities { get; set; }
        public IDbSet<VendorAttachfileEntity> VendorAttachfileEntities { get; set; }
        public IDbSet<ConsultantEntity> ConsultantEntities { get; set; }
        public IDbSet<ConsultantAttachfileEntity> ConsultantAttachfileEntities { get; set; }

        #endregion

        #region Dev

        public IDbSet<BugEntity> BugEntities { get; set; }
        public IDbSet<BugAssignEntity> BugAssignEntities { get; set; }
        public IDbSet<BugAttachfileEntity> BugAttachfileEntities { get; set; }
        public IDbSet<BugHistoryEntity> BugHistoryEntities { get; set; }
        public IDbSet<PublishEntity> PublishEntities { get; set; }

        #endregion

        #endregion

        #region IDbContext

        /// <summary>
        /// Returns the set for <typeparamref name="TEntity"/>, typed as <see cref="IDbSet{TEntity}"/>
        /// to satisfy <see cref="IDbContext"/>.
        /// </summary>
        public new IDbSet<TEntity> Set<TEntity>() where TEntity : class
        {
            return base.Set<TEntity>();
        }

        /// <summary>Marks <paramref name="entity"/> as Added, attaching it first if it is detached.</summary>
        public void SetAsAdded<TEntity>(TEntity entity) where TEntity : class
        {
            UpdateEntityState(entity, EntityState.Added);
        }

        /// <summary>Marks <paramref name="entity"/> as Modified, attaching it first if it is detached.</summary>
        public void SetAsModified<TEntity>(TEntity entity) where TEntity : class
        {
            UpdateEntityState(entity, EntityState.Modified);
        }

        /// <summary>Marks <paramref name="entity"/> as Deleted, attaching it first if it is detached.</summary>
        public void SetAsDeleted<TEntity>(TEntity entity) where TEntity : class
        {
            UpdateEntityState(entity, EntityState.Deleted);
        }

        /// <summary>
        /// Opens the underlying connection if it is closed and ensures an active transaction exists.
        /// An existing, still-active transaction is reused.
        /// </summary>
        public void BeginTransaction()
        {
            objectContext = ((IObjectContextAdapter)this).ObjectContext;
            if (objectContext.Connection.State == ConnectionState.Closed)
            {
                objectContext.Connection.Open();
            }

            // A transaction reporting a null Connection has already completed; dispose it before
            // starting a fresh one instead of leaking it.
            if (transaction != null && transaction.Connection == null)
            {
                ReleaseTransaction();
            }

            if (transaction == null)
            {
                transaction = objectContext.Connection.BeginTransaction();
            }
        }

        /// <summary>
        /// Saves all pending changes inside a transaction and commits it. On any failure the
        /// transaction is rolled back and the original exception is rethrown.
        /// </summary>
        /// <returns>The number of state entries written to the database.</returns>
        public int Commit()
        {
            try
            {
                BeginTransaction();
                int affected = SaveChanges();
                transaction.Commit();
                ReleaseTransaction();

                return affected;
            }
            catch (Exception)
            {
                Rollback();
                throw;
            }
        }

        /// <summary>
        /// Rolls back the current transaction, if any, and disposes it. Safe to call when no
        /// transaction was started or the transaction has already completed (then it is a no-op),
        /// so it never hides the exception that triggered it.
        /// </summary>
        public void Rollback()
        {
            if (transaction == null)
            {
                return;
            }

            try
            {
                // A completed transaction reports a null Connection; rolling it back would throw.
                if (transaction.Connection != null)
                {
                    transaction.Rollback();
                }
            }
            finally
            {
                ReleaseTransaction();
            }
        }

        /// <summary>
        /// Async counterpart of <see cref="Commit"/>.
        /// </summary>
        /// <returns>The number of state entries written to the database.</returns>
        public Task<int> CommitAsync()
        {
            return CommitAsync(CancellationToken.None);
        }

        /// <summary>
        /// Async counterpart of <see cref="Commit"/> that supports cancellation. On any failure,
        /// including cancellation, the transaction is rolled back and the exception is rethrown.
        /// </summary>
        /// <param name="cancellationToken">Token used to cancel the save.</param>
        /// <returns>The number of state entries written to the database.</returns>
        public async Task<int> CommitAsync(CancellationToken cancellationToken)
        {
            try
            {
                BeginTransaction();
                int affected = await SaveChangesAsync(cancellationToken).ConfigureAwait(false);
                transaction.Commit();
                ReleaseTransaction();

                return affected;
            }
            catch (Exception)
            {
                Rollback();
                throw;
            }
        }

        /// <summary>
        /// Disposes any transaction still held before the context and its connection are disposed.
        /// </summary>
        /// <param name="disposing">True when called from <c>Dispose()</c>.</param>
        protected override void Dispose(bool disposing)
        {
            if (disposing)
            {
                ReleaseTransaction();
            }
            base.Dispose(disposing);
        }

        /// <summary>Disposes and forgets the current transaction, if any.</summary>
        private void ReleaseTransaction()
        {
            DbTransaction finished = transaction;
            transaction = null;
            if (finished != null)
            {
                finished.Dispose();
            }
        }

        private void UpdateEntityState<TEntity>(TEntity entity, EntityState entityState) where TEntity : class
        {
            var dbEntityEntry = GetDbEntityEntrySafely(entity);
            dbEntityEntry.State = entityState;
        }

        /// <summary>
        /// Returns the change-tracker entry for <paramref name="entity"/>, attaching the entity to
        /// its set first when it is detached.
        /// </summary>
        private DbEntityEntry GetDbEntityEntrySafely<TEntity>(TEntity entity) where TEntity : class
        {
            var dbEntityEntry = Entry<TEntity>(entity);
            if (dbEntityEntry.State == EntityState.Detached)
            {
                Set<TEntity>().Attach(entity);
            }
            return dbEntityEntry;
        }

        #endregion

    }

}
